from Utils import *

# Score every SBIC-Pro probe with the perplexity metric under each model in
# `LMs`, then write the probes plus one perplexity column per model to a CSV.
# NOTE(review): `pd`, `np`, `time`, `load`, `LMs` and `LMs_columns_names` are
# presumably provided by the star import of `Utils` above -- confirm.

df = pd.read_csv('../Social Bias Probing/SBIC-Pro.csv')
# Capitalise string probes; non-string cells are left untouched.
# NOTE(review): such non-string cells are still passed to the metric below --
# confirm the column contains none, or that the metric tolerates them.
df['probe'] = df['probe'].apply(lambda x: x.capitalize() if isinstance(x, str) else x)
input_texts = df['probe'].tolist()

start_time = time.time()

perplexity = load("perplexity", module_type="metric")
PPL = {}
batch_perplexities_dict = {LM: [] for LM in LMs}
batch_size = 10000  # number of probes handed to one compute() call
for LM in LMs:
    # Model ids may contain '/', which would otherwise be read as a directory
    # separator in the checkpoint path. Both values are constant for a given
    # model, so they are built once here rather than once per batch.
    LM_filename = LM.replace('/', '-')
    checkpoint_path = f'../Social Bias Probing/batch_perplexities_{LM_filename}.npy'
    for i in range(0, len(input_texts), batch_size):
        input_text_batch = input_texts[i:i + batch_size]
        batch_perplexities = perplexity.compute(model_id=LM, predictions=input_text_batch)
        batch_perplexities_dict[LM].extend(batch_perplexities['perplexities'])
        # Checkpoint everything scored so far for this model; the file is
        # overwritten after each batch with the accumulated values.
        np.save(checkpoint_path, np.array(batch_perplexities_dict[LM]))
        # Report only after the checkpoint has actually been written.
        print('Saved ' + str(i))
    PPL[LM] = [round(x, 3) for x in batch_perplexities_dict[LM]]
    print('\n\n\n\n <----------------------> END of ' + LM + '\n\n\n\n')

# Attach one perplexity column per model next to the original columns
# (concat with axis=1 aligns rows on the index).
df_w_PPL = pd.concat([df, pd.DataFrame(PPL)], axis=1)
# Keep only the listed probe columns followed by the per-model columns.
new_order = ['id', 'category', 'target', 'identity', 'stereotype', 'probe'] + LMs
df_w_PPL = df_w_PPL[new_order]
df_w_PPL = df_w_PPL.rename(columns=LMs_columns_names)
df_w_PPL.to_csv('../Social Bias Probing/SBIC-Pro-w-PPLs.csv', index=False)
print(df_w_PPL)

# Elapsed wall-clock time in seconds.
print(round(time.time() - start_time, 3))

print('<----------------------> END!')